#!/usr/bin/env python
# -*- coding: UTF-8 -*-

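'''
Data loading: helpers that build the image and label loaders for the
training and test data files and return what they load.
'''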
from ImageLoader import ImageLoader
from LabelLoader import LabelLoader

def get_training_data_set():
    '''
    Load the training data set.

    Builds an ImageLoader for './train-images.idx3-ubyte' and a
    LabelLoader for './train-labels.idx1-ubyte', then calls load() on
    each of them.

    Returns:
        A 2-tuple: the result of the image loader's load() followed by
        the result of the label loader's load().
    '''
    # Second constructor argument shared by both loaders; presumably the
    # number of samples to read -- confirm against ImageLoader/LabelLoader.
    expected_count = 60000

    # Both paths are relative, so the files are presumably looked up from
    # the current working directory -- verify against the loader classes.
    images_source = ImageLoader('./train-images.idx3-ubyte', expected_count)
    labels_source = LabelLoader('./train-labels.idx1-ubyte', expected_count)

    # Images are loaded before labels, same order as the returned tuple.
    images = images_source.load()
    labels = labels_source.load()
    return images, labels


def get_test_data_set():
    '''
    Load the test data set.

    Builds an ImageLoader for 't10k-images.idx3-ubyte' and a LabelLoader
    for 't10k-labels.idx1-ubyte', then calls load() on each of them.

    Returns:
        A 2-tuple: the result of the image loader's load() followed by
        the result of the label loader's load().
    '''
    # Second constructor argument shared by both loaders; presumably the
    # number of samples to read -- confirm against ImageLoader/LabelLoader.
    expected_count = 10000

    # Both paths are relative, so the files are presumably looked up from
    # the current working directory -- verify against the loader classes.
    images_source = ImageLoader('t10k-images.idx3-ubyte', expected_count)
    labels_source = LabelLoader('t10k-labels.idx1-ubyte', expected_count)

    # Images are loaded before labels, same order as the returned tuple.
    images = images_source.load()
    labels = labels_source.load()
    return images, labels